{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2ca8c35e-d7c4-4eed-a69b-324a3ffdbea8",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ourownstory/neural_prophet/blob/master/tutorials/feature-use/benchmarking.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f2af163-c378-4e01-8d5b-3def6194c29e",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Running benchmarking experiments\n",
    "Note: The Benchmarking Framework currently does not properly support auto-regression or lagged covariates with multi-step-ahead forecasts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "80400b6d-ca57-47ba-9dc5-0da3885ab6b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING - (NP.benchmark.<module>) - Benchmarking Framework is not covered by tests. Please report any bugs you find.\n"
     ]
    }
   ],
   "source": [
    "if 'google.colab' in str(get_ipython()):\n",
    "    !pip install git+https://github.com/ourownstory/neural_prophet.git # may take a while\n",
    "    #!pip install neuralprophet # much faster, but may not have the latest upgrades/bugfixes\n",
    "    \n",
    "import pandas as pd\n",
    "from neuralprophet import NeuralProphet, set_log_level\n",
    "from neuralprophet.benchmark import Dataset, NeuralProphetModel, ProphetModel\n",
    "from neuralprophet.benchmark import SimpleBenchmark, CrossValidationBenchmark\n",
    "set_log_level(\"ERROR\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5e79793-8ebb-4d06-b021-82b49d107653",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "930697ee-179c-4821-bbb7-c2f5e4588093",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_location = \"https://raw.githubusercontent.com/ourownstory/neuralprophet-data/main/datasets/\"\n",
    "\n",
    "air_passengers_df = pd.read_csv(data_location + 'air_passengers.csv')\n",
    "peyton_manning_df = pd.read_csv(data_location + 'wp_log_peyton_manning.csv')\n",
    "# retail_sales_df = pd.read_csv(data_location + 'retail_sales.csv')\n",
    "# yosemite_temps_df = pd.read_csv(data_location +  'yosemite_temps.csv')\n",
    "# ercot_load_df = pd.read_csv(data_location +  'ERCOT_load.csv')[['ds', 'y']]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40f0b0c0-eeaa-4a1f-9251-84d41da1deae",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 0. Configure Datasets and Model Parameters\n",
    "First, we define the datasets that we would like to benchmark on.\n",
    "Next, we define the models that we want to evaluate and set their hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ac375354-7755-48a7-b628-d322c0232f3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_list = [\n",
    "    Dataset(df = air_passengers_df, name = \"air_passengers\", freq = \"MS\"),\n",
    "    # Dataset(df = peyton_manning_df, name = \"peyton_manning\", freq = \"D\"),\n",
    "    # Dataset(df = retail_sales_df, name = \"retail_sales\", freq = \"D\"),\n",
    "    # Dataset(df = yosemite_temps_df, name = \"yosemite_temps\", freq = \"5min\"),\n",
    "    # Dataset(df = ercot_load_df, name = \"ercot_load\", freq = \"H\"),\n",
    "]\n",
    "model_classes_and_params = [\n",
    "    (NeuralProphetModel, {\"seasonality_mode\": \"multiplicative\", \"learning_rate\": 0.1}),\n",
    "    (ProphetModel, {\"seasonality_mode\": \"multiplicative\"})\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60e64a14-ab54-43b8-8115-1c6c02996e7c",
   "metadata": {},
   "source": [
    "Note: All classes used in the benchmark framework are dataclasses, \n",
    "so they come with an auto-generated `__repr__` and can be printed or displayed for inspection:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fc5ac40e-8ce3-427f-be00-f2590dacc3e6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(neuralprophet.benchmark.NeuralProphetModel,\n",
       "  {'seasonality_mode': 'multiplicative', 'learning_rate': 0.1}),\n",
       " (neuralprophet.benchmark.ProphetModel,\n",
       "  {'seasonality_mode': 'multiplicative'})]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_classes_and_params"
   ]
  },
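  {
   "cell_type": "markdown",
   "id": "a3d5c1e8-4f27-4b9a-8e61-0c7d2f9b1a01",
   "metadata": {},
   "source": [
    "As a minimal sketch of why dataclass instances print readably: the `@dataclass` decorator generates a `__repr__` that lists every field. `ToyExperiment` below is a hypothetical stand-in for illustration, not one of the benchmark framework's classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d5c1e8-4f27-4b9a-8e61-0c7d2f9b1a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass, field\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ToyExperiment:\n",
    "    # Hypothetical stand-in for illustration; not part of neuralprophet.benchmark.\n",
    "    name: str\n",
    "    test_percentage: float\n",
    "    metrics: list = field(default_factory=list)\n",
    "\n",
    "\n",
    "toy_experiment = ToyExperiment(name='air_passengers', test_percentage=25, metrics=['MAE'])\n",
    "# The generated __repr__ shows the class name followed by all fields.\n",
    "toy_repr = repr(toy_experiment)\n",
    "print(toy_repr)"
   ]
  },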
  {
   "cell_type": "markdown",
   "id": "4b924887-1728-4506-9275-92ec4be61033",
   "metadata": {},
   "source": [
    "## 1. SimpleBenchmark\n",
    "Setting up a series of train-test experiments is quick:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27e909a2-cbc5-4703-a1e6-8963921b6364",
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark = SimpleBenchmark(\n",
    "    model_classes_and_params=model_classes_and_params, # iterate over this list of tuples\n",
    "    datasets=dataset_list, # iterate over this list\n",
    "    metrics=[\"MAE\", \"MSE\", \"MASE\", \"RMSE\"],\n",
    "    test_percentage=25,\n",
    ")\n",
    "results_train, results_test = benchmark.run()"
   ]
  },
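  {
   "cell_type": "markdown",
   "id": "b7e2d4f6-1a38-4c5b-9d72-3e8f0a6c2b01",
   "metadata": {},
   "source": [
    "To make `test_percentage=25` concrete, here is a rough sketch of a chronological hold-out split: the last 25% of the rows form the test set and everything before them the training set. The helper `holdout_sizes` and its rounding rule are assumptions made for illustration; the framework may round differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e2d4f6-1a38-4c5b-9d72-3e8f0a6c2b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "def holdout_sizes(n_samples, test_percentage):\n",
    "    # Hypothetical helper: returns (n_train, n_test) for a chronological split.\n",
    "    # Assumption: the test size is rounded down and is at least 1.\n",
    "    n_test = max(1, int(n_samples * test_percentage / 100))\n",
    "    return n_samples - n_test, n_test\n",
    "\n",
    "\n",
    "# 144 rows is an illustrative series length (12 years of monthly data).\n",
    "demo_n_train, demo_n_test = holdout_sizes(n_samples=144, test_percentage=25)\n",
    "print(demo_n_train, demo_n_test)"
   ]
  },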
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "270a689e-47f5-45cb-9b2f-18196d149348",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data</th>\n",
       "      <th>model</th>\n",
       "      <th>params</th>\n",
       "      <th>MAE</th>\n",
       "      <th>MSE</th>\n",
       "      <th>MASE</th>\n",
       "      <th>RMSE</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>NeuralProphet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative', 'learni...</td>\n",
       "      <td>24.814017</td>\n",
       "      <td>844.513472</td>\n",
       "      <td>0.571375</td>\n",
       "      <td>29.060514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>Prophet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative'}</td>\n",
       "      <td>29.818648</td>\n",
       "      <td>1142.139138</td>\n",
       "      <td>0.686614</td>\n",
       "      <td>33.795549</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             data          model  \\\n",
       "0  air_passengers  NeuralProphet   \n",
       "1  air_passengers        Prophet   \n",
       "\n",
       "                                              params        MAE          MSE  \\\n",
       "0  {'seasonality_mode': 'multiplicative', 'learni...  24.814017   844.513472   \n",
       "1             {'seasonality_mode': 'multiplicative'}  29.818648  1142.139138   \n",
       "\n",
       "       MASE       RMSE  \n",
       "0  0.571375  29.060514  \n",
       "1  0.686614  33.795549  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_test"
   ]
  },
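  {
   "cell_type": "markdown",
   "id": "c4f8a1b3-6d29-4e7c-8a53-9b1e5d7f3c01",
   "metadata": {},
   "source": [
    "For readers unfamiliar with the metric names, the sketch below computes MAE, MSE, RMSE and MASE using their textbook definitions. It is not the framework's implementation: scaling MASE by the error of a one-step naive forecast on the training data is an assumption, and all numbers are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4f8a1b3-6d29-4e7c-8a53-9b1e5d7f3c02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def mae(actual, predicted):\n",
    "    # Mean absolute error.\n",
    "    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)\n",
    "\n",
    "\n",
    "def mse(actual, predicted):\n",
    "    # Mean squared error.\n",
    "    return sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual)\n",
    "\n",
    "\n",
    "def rmse(actual, predicted):\n",
    "    # Root mean squared error.\n",
    "    return math.sqrt(mse(actual, predicted))\n",
    "\n",
    "\n",
    "def mase(actual, predicted, train):\n",
    "    # Assumption: scale by the MAE of a one-step naive forecast on the training data.\n",
    "    naive_mae = sum(abs(b - a) for a, b in zip(train, train[1:])) / (len(train) - 1)\n",
    "    return mae(actual, predicted) / naive_mae\n",
    "\n",
    "\n",
    "# Made-up numbers, purely for illustration.\n",
    "toy_train = [10.0, 12.0, 11.0, 13.0]\n",
    "toy_actual = [14.0, 15.0]\n",
    "toy_predicted = [13.0, 17.0]\n",
    "toy_scores = {\n",
    "    'MAE': mae(toy_actual, toy_predicted),\n",
    "    'MSE': mse(toy_actual, toy_predicted),\n",
    "    'RMSE': rmse(toy_actual, toy_predicted),\n",
    "    'MASE': mase(toy_actual, toy_predicted, toy_train),\n",
    "}\n",
    "print(toy_scores)"
   ]
  },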
  {
   "cell_type": "markdown",
   "id": "7977f0ce-d0c6-4d0c-9edc-a62862d1ea1e",
   "metadata": {},
   "source": [
    "## 2. CrossValidationBenchmark\n",
    "Setting up a series of cross-validated experiments is just as simple:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61a78380-0e77-4ecf-a356-4c89cdf52625",
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark_cv = CrossValidationBenchmark(\n",
    "    model_classes_and_params=model_classes_and_params, # iterate over this list of tuples\n",
    "    datasets=dataset_list, # iterate over this list\n",
    "    metrics=[\"MASE\", \"RMSE\"],\n",
    "    test_percentage=10,\n",
    "    num_folds=3,\n",
    "    fold_overlap_pct=0,\n",
    ")\n",
    "results_summary, results_train, results_test = benchmark_cv.run()"
   ]
  },
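  {
   "cell_type": "markdown",
   "id": "d9a6b2c7-3e41-4f8d-a265-7c0f4e8b5d01",
   "metadata": {},
   "source": [
    "The arguments `test_percentage`, `num_folds` and `fold_overlap_pct` control how the series is cut into folds. The sketch below shows one plausible layout (expanding training window, test folds taken from the end of the series). The helper `fold_boundaries` is hypothetical, and the framework's actual splitting logic may differ in its details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9a6b2c7-3e41-4f8d-a265-7c0f4e8b5d02",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fold_boundaries(n_samples, test_percentage, num_folds, fold_overlap_pct=0):\n",
    "    # Hypothetical helper: returns (train_end, test_end) index pairs, oldest fold first.\n",
    "    # Assumptions: every test fold holds test_percentage % of the samples,\n",
    "    # consecutive folds share fold_overlap_pct % of a fold, and each model\n",
    "    # trains on all samples before its test fold.\n",
    "    fold_size = max(1, int(n_samples * test_percentage / 100))\n",
    "    overlap = int(fold_size * fold_overlap_pct / 100)\n",
    "    step = fold_size - overlap\n",
    "    folds = []\n",
    "    for i in range(num_folds):\n",
    "        test_end = n_samples - (num_folds - 1 - i) * step\n",
    "        train_end = test_end - fold_size\n",
    "        folds.append((train_end, test_end))\n",
    "    return folds\n",
    "\n",
    "\n",
    "# 144 samples is an illustrative series length.\n",
    "demo_folds = fold_boundaries(n_samples=144, test_percentage=10, num_folds=3)\n",
    "print(demo_folds)"
   ]
  },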
  {
   "cell_type": "markdown",
   "id": "b2ca0381-7469-427f-b924-1269ff9f4c50",
   "metadata": {},
   "source": [
    "We now also get a summary DataFrame showing the metrics' mean and standard deviation over all folds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "653ed48e-221d-48b2-81af-f1a172fe2beb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data</th>\n",
       "      <th>model</th>\n",
       "      <th>params</th>\n",
       "      <th>MASE</th>\n",
       "      <th>RMSE</th>\n",
       "      <th>MASE_std</th>\n",
       "      <th>RMSE_std</th>\n",
       "      <th>split</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>NeuralProphet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative', 'learni...</td>\n",
       "      <td>0.278579</td>\n",
       "      <td>7.664800</td>\n",
       "      <td>0.011388</td>\n",
       "      <td>0.696316</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>Prophet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative'}</td>\n",
       "      <td>0.310869</td>\n",
       "      <td>8.616463</td>\n",
       "      <td>0.021078</td>\n",
       "      <td>1.266764</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>NeuralProphet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative', 'learni...</td>\n",
       "      <td>0.482023</td>\n",
       "      <td>22.380878</td>\n",
       "      <td>0.140785</td>\n",
       "      <td>5.769275</td>\n",
       "      <td>test</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>Prophet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative'}</td>\n",
       "      <td>0.479454</td>\n",
       "      <td>22.778341</td>\n",
       "      <td>0.109195</td>\n",
       "      <td>4.224042</td>\n",
       "      <td>test</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             data          model  \\\n",
       "0  air_passengers  NeuralProphet   \n",
       "1  air_passengers        Prophet   \n",
       "0  air_passengers  NeuralProphet   \n",
       "1  air_passengers        Prophet   \n",
       "\n",
       "                                              params      MASE       RMSE  \\\n",
       "0  {'seasonality_mode': 'multiplicative', 'learni...  0.278579   7.664800   \n",
       "1             {'seasonality_mode': 'multiplicative'}  0.310869   8.616463   \n",
       "0  {'seasonality_mode': 'multiplicative', 'learni...  0.482023  22.380878   \n",
       "1             {'seasonality_mode': 'multiplicative'}  0.479454  22.778341   \n",
       "\n",
       "   MASE_std  RMSE_std  split  \n",
       "0  0.011388  0.696316  train  \n",
       "1  0.021078  1.266764  train  \n",
       "0  0.140785  5.769275   test  \n",
       "1  0.109195  4.224042   test  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "12697c82-16f0-4807-8ac5-66aa730fa0f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbkAAAD4CAYAAABxJ5hVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAARaklEQVR4nO3df5BdZX3H8ffXJBACURoSTSDgrjGtBIIBAwIlGih0QrGAP/hhcUbBkTijQFNohxE62hnt4ECKpYA2/sKiEkREVCwTSk2JdiwkupBgSFWSlER+BgOoJCbx2z/u2XCz7CZ32b177z77fs3s5JznnHvO95y5k8885zznnshMJEkq0ataXYAkSc1iyEmSimXISZKKZchJkoplyEmSijW61QVoVxMnTsyOjo5WlyFJw8qKFSueycxJPdsNuTbT0dHB8uXLW12GJA0rEbG+t3YvV0qSimXISZKKZchJkorlPTlJGma2bdvGhg0b2LJlS6tLGXJjx45l6tSpjBkzpqH1DTlJGmY2bNjA+PHj6ejoICJaXc6QyUw2bdrEhg0b6OzsbOgzXq6UpGFmy5YtHHDAASMq4AAiggMOOKBfPVhDTpKGoZEWcN36e9yGnCSpWN6Tk6RhruPyuwZ1e+uuOm23y0eNGsXMmTPZvn07nZ2d3Hzzzey///6sW7eOzs5OrrjiCj75yU8C8MwzzzBlyhTmz5/P9ddfz5o1a5g/fz6bN29m69atzJkzh0WLFrF06VLOOOOMXe61XXPNNZx88skDOhZ7cpKkftlnn33o6upi1apVTJgwgRtuuGHnss7OTu6666XQve222zjssMN2zl988cUsWLCArq4uVq9ezUUXXbRz2Zw5c+jq6tr5N9CAA0NOkjQAxx13HBs3btw5P27cOA499NCdP0946623cvbZZ+9c/vjjjzN16tSd8zNnzmxqfYacJOkV2bFjB/feey+nn376Lu3nnnsuixcv5rHHHmPUqFEceOCBO5ctWLCAk046iVNPPZVrr72WzZs371y2bNkyZs2atfPvl7/85YBrNOQkSf3y4osvMmvWLCZPnsyTTz7JKaecssvyefPmcc8997B48WLOOeecXZadf/75rF69mrPOOoulS5dy7LHHsnXrVuDllyunTZs24FoNOUlSv3Tfk1u/fj2Zucs9OYC99tqLt7zlLSxcuJD3vOc9L/v8gQceyAUXXMCdd97J6NGjWbVqVdNqNeQkSa/IuHHjuO6661i4cCHbt2/fZdmll17Kpz/9aSZMmLBL+9133822bdsAeOKJJ9i0aRMHHXRQ02r0EQJJGub2NOS/mY488kiOOOIIbrnlFubMmbOz/bDDDttlVGW3JUuWcMkllzB27FgArr76aiZPnswjjzyy855ctyuvvLLXnmB/RGYOaAMaXLNnz05fmippd1avXs2hhx7a6jJaprfjj4gVmTm757perpQkFcuQkyQVy5CTpGFopN5q6u9xG3KSNMyMHTuWTZs2jbig636fXPeglUY4ulKShpmpU6eyYcMGnn766VaXMuS63wzeKENOkoaZMWPGNPxm7JHOy5WSpGIZcpKkYhlykqRieU+uzazc+Nygv+VXktpds36azJ6cJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqVhFh1xE7IiIrohYFRG3RcS4QdhmR0Ss6udnzoyIGQPdtySpf4oOOeDFzJyVmYcDvwc+XL8wIkYPUR1nAoacJA2x0kOu3jLgjRExNyKWRcR3gJ9FxNiI+HJErIyIn0bEiQAR8YGIuDMilkbEzyPi43XbGhURn4+IhyNiSUTsU31mWkTcHRErqn28KSKOB04Hrq56ldOG/MglaYQaqp5MS1U9
tlOBu6umo4DDM3NtRFwKZGbOjIg3AUsi4o+r9Y4BDgd+BzwQEXcBzwDTgfdm5oci4hvAu4GvAouAD2fmzyPircCNmXlSFajfy8xv9lHfhcCFAKNePWnwT4AkjVClh9w+EdFVTS8DvggcD9yfmWur9hOAfwHIzEciYj3QHXL3ZOYmgIj4VrXut4G1mdm93RVAR0TsV237tojo3v/ejRSZmYuoBSR7T5me/T5KSVKvSg+5FzNzVn1DFUC/bfDzPQOne35rXdsOYB9ql34399yfJKl1RtI9ub4sA84DqC5THgKsqZadEhETqntuZwI/6msjmfk8sDYizqq2FRHx5mrxC8D45pQvSeqLIQc3Aq+KiJXArcAHMrO7p3Y/cDvwEHB7Zi7fw7bOAz4YEQ8CDwNnVO2Lgb+tBrY48ESShkjRlyszc79e2pYCS+vmtwDn97GJDZl5Zo/Pr6M2GKV7/pq66bXAvF72+SN8hECShpw9OUlSsYruyQ1EZt4E3NTiMiRJA2BPTpJULENOklQsQ06SVCxDTpJULENOklQsQ06SVCxDTpJULENOklQsQ06SVCxDTpJULENOklQsQ06SVCxDTpJULENOklQsQ06SVCxDTpJULENOklQsQ06SVCxDTpJUrNGtLkC7mnnQa1h+1WmtLkOSimBPTpJULENOklQsQ06SVCxDTpJULENOklQsQ06SVCxDTpJULENOklQsQ06SVCxDTpJULENOklQsQ06SVCxDTpJUrN2+hSAiJuxueWY+O7jlSJI0ePb0qp0VQALRy7IE3jDoFUmSNEh2G3KZ2TlUhUiSNNgauicXNe+LiL+v5g+JiGOaW5okSQPT6MCTG4HjgL+q5l8AbmhKRZIkDZI93ZPr9tbMPCoifgqQmb+OiL2aWJckSQPWaE9uW0SMojbYhIiYBPyhaVVJkjQIGg2564A7gNdGxKeAHwL/2LSqJEkaBA1drszMr0XECuDPqD1OcGZmrm5qZZIkDVB/HgZ/CrilfpkPg0uS2ll/HgY/BPh1Nb0/8H+Az9FJktrWbu/JZWZnZr4B+A/gLzNzYmYeALwDWDIUBUqS9Eo1OvDk2Mz8fvdMZv47cHxzSpIkaXA0+pzcryLiSuCr1fx5wK+aU5IkSYOj0Z7ce4FJ1B4juAN4bdUmSVLbavQRgmeBSyJifG02f9PcsiRJGrhGf6B5ZvWTXquAhyNiRUQc3tzSJEkamEYvV/4r8DeZ+frMfD1wKbCoeWVJkjRwjYbcvpn5g+6ZzFwK7NuUiiRJGiSNjq58tHqX3M3V/PuAR5tTkiRJg6PRntwF1EZX3l79TQTOb1ZRkiQNhkZDbhpwcLX+XtR+qPm+ZhUlSdJgaPRy5deAy6iNrvQ9cpKkYaHRkHs6M7/b1EokSRpkjYbcxyPiC8C9wNbuxsz8VlOqkiRpEDQacucDbwLG8NLlygQMOUlS22o05I7OzD9paiWSJA2yRkdX/ndEzGhqJZIkDbJGe3LHAl0RsZbaPbmg9kPNRzStMkmSBqjRkJvX1CokSWqCRl+1s77ZhUiSNNgavScnSdKwY8hJkoplyEmSimXISZKKZchJkoplyEmSimXISZKKZchJkoplyEmSitXoz3ppiKzc+Bwdl9/V6jIkqSnWXXXakO7PnpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWE0LuYjIiFhYN39ZRHyiWfur28/SiJhdTa+LiJUR8VBELImIyYO0j9/0c/25EXH8YOxbktS4ZvbktgLvioiJg7nRqOlP3Sdm5hHAcuBjA9zWKzUXMOQkaYg18z/47cAiYEHPBRExKSJuj4gHqr8/rdo/ERGX1a23KiI6qr81EfFvwCrg4Ij4bEQsj4iHI+IfGqjnPuCN
fWzr6mpfKyPinGrfcyPivoi4q1r/c/WBGBGfiogHI+LHEfG6vo4rIjqADwMLIqIrIua8wvMpSeqnZvdibgDOi4jX9Gj/Z+DazDwaeDfwhQa2NR24MTMPy8z1wBWZORs4Anh7RByxh8+/A1jZc1vAbGAW8GbgZODqiJhSrXcMcBEwA5gGvKtq3xf4cWa+mVp4fqiv48rMdcDnqvZZmbmsZ2ERcWEV2Mt3/O65Bk6FJKkRo5u58cx8vuoxXQy8WLfoZGBGRHTPvzoi9tvD5tZn5o/r5s+OiAupHcMUakH0UC+f+0FE7KiWXQns32NbJwC3ZOYO4MmI+C/gaOB54P7MfBQgIm6p1v0m8Hvge9XnVwCnDOC4yMxF1Hq97D1leu5pfUlSY5oacpXPAD8BvlzX9irg2MzcUr9iRGxn197l2Lrp39at1wlcBhydmb+OiJt6rFvvxMx8pu6z+9dvaw96Bk73/LbM7J7ewUvnsa/janB3kqTB1PRBF5n5LPAN4IN1zUuoXQYEICJmVZPrgKOqtqOAzj42+2pqQfVcdT/s1AGUuAw4JyJGRcQk4G3A/dWyYyKis7oXdw7wwz1sq6/jegEYP4AaJUmvwFA9J7cQqB9leTEwuxra/zNqAzMAbgcmRMTDwEeB/+1tY5n5IPBT4BHg68CPBlDbHdQuZT4I/Cfwd5n5RLXsAeB6YDWwtlp3d/o6ru8C73TgiSQNrXjpqpvqRcRc4LLMfMdQ7nfvKdNzyvs/M5S7lKQhs+6q05qy3YhYUQ1G3IW/eCJJKtZQDDwZljJzKbC0xWVIkgbAnpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWKNbXYB2NfOg17D8qtNaXYYkFcGenCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYhpwkqViGnCSpWIacJKlYkZmtrkF1IuIFYE2r62hDE4FnWl1EG/K89M7z0ruSz8vrM3NSz8bRrahEu7UmM2e3uoh2ExHLPS8v53npneeldyPxvHi5UpJULENOklQsQ679LGp1AW3K89I7z0vvPC+9G3HnxYEnkqRi2ZOTJBXLkJMkFcuQaxMRMS8i1kTELyLi8lbX0y4iYl1ErIyIrohY3up6WiUivhQRT0XEqrq2CRFxT0T8vPr3j1pZYyv0cV4+EREbq+9MV0T8RStrbIWIODgifhARP4uIhyPikqp9xH1nDLk2EBGjgBuAU4EZwHsjYkZrq2orJ2bmrJH2fE8PNwHzerRdDtybmdOBe6v5keYmXn5eAK6tvjOzMvP7Q1xTO9gOXJqZM4BjgY9U/6eMuO+MIdcejgF+kZmPZubvgcXAGS2uSW0kM+8Dnu3RfAbwlWr6K8CZQ1lTO+jjvIx4mfl4Zv6kmn4BWA0cxAj8zhhy7eEg4LG6+Q1VmyCBJRGxIiIubHUxbeZ1mfl4Nf0E8LpWFtNmPhoRD1WXM4u/JLc7EdEBHAn8DyPwO2PIqd2dkJlHUbuU+5GIeFurC2pHWXsWyOeBaj4LTANmAY8DC1taTQtFxH7A7cBfZ+bz9ctGynfGkGsPG4GD6+anVm0jXmZurP59CriD2qVd1TwZEVMAqn+fanE9bSEzn8zMHZn5B+DzjNDvTESMoRZwX8vMb1XNI+47Y8i1hweA6RHRGRF7AecC32lxTS0XEftGxPjuaeDPgVW7/9SI8h3g/dX0+4E7W1hL2+j+T7zyTkbgdyYiAvgisDoz/6lu0Yj7zviLJ22iGub8GWAU8KXM/FRrK2q9iHgDtd4b1N6Y8fWRel4i4hZg
LrVXpTwJfBz4NvAN4BBgPXB2Zo6oQRh9nJe51C5VJrAOmF93H2pEiIgTgGXASuAPVfPHqN2XG1HfGUNOklQsL1dKkoplyEmSimXISZKKZchJkoplyEmSimXISZKKZchJkor1/8hJW721kiciAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
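   ],
   "source": [
    "# Keep only the test-split rows of the air_passengers dataset\n",
    "air_passengers = results_summary[results_summary['data'] == 'air_passengers']\n",
    "air_passengers = air_passengers[air_passengers['split'] == 'test']\n",
    "# DataFrame.plot returns a matplotlib Axes, so name it ax rather than plt\n",
    "ax = air_passengers.plot(x='model', y='RMSE', kind='barh')"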
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b32fd49-55c6-4cd4-9082-3ee073cc741e",
   "metadata": {},
   "source": [
    "The metrics for each fold are also recorded individually:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5bdf240d-a6a8-4433-b12f-3fa3921d8b7a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data</th>\n",
       "      <th>model</th>\n",
       "      <th>params</th>\n",
       "      <th>MASE</th>\n",
       "      <th>RMSE</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>NeuralProphet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative', 'learni...</td>\n",
       "      <td>[0.5014669235699923, 0.3006991286697633, 0.643...</td>\n",
       "      <td>[20.975299319246627, 16.12341769010618, 30.043...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>Prophet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative'}</td>\n",
       "      <td>[0.5885898654326461, 0.3302688815416577, 0.519...</td>\n",
       "      <td>[24.617703272976527, 16.936631879241148, 26.78...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             data          model  \\\n",
       "0  air_passengers  NeuralProphet   \n",
       "1  air_passengers        Prophet   \n",
       "\n",
       "                                              params  \\\n",
       "0  {'seasonality_mode': 'multiplicative', 'learni...   \n",
       "1             {'seasonality_mode': 'multiplicative'}   \n",
       "\n",
       "                                                MASE  \\\n",
       "0  [0.5014669235699923, 0.3006991286697633, 0.643...   \n",
       "1  [0.5885898654326461, 0.3302688815416577, 0.519...   \n",
       "\n",
       "                                                RMSE  \n",
       "0  [20.975299319246627, 16.12341769010618, 30.043...  \n",
       "1  [24.617703272976527, 16.936631879241148, 26.78...  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_test"
   ]
  },
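  {
   "cell_type": "markdown",
   "id": "e2b7c9d4-5f16-4a3e-b847-1d6a9f0c7e01",
   "metadata": {},
   "source": [
    "In principle, the summary table can be recomputed from these per-fold lists. The small sketch below uses made-up fold values and assumes the population standard deviation (which is also the default of `numpy.std`); whether the framework uses exactly this convention is an assumption."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2b7c9d4-5f16-4a3e-b847-1d6a9f0c7e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "from statistics import mean, pstdev\n",
    "\n",
    "\n",
    "def summarize_folds(fold_values):\n",
    "    # Mean and population standard deviation (ddof=0) over the folds.\n",
    "    return {'mean': mean(fold_values), 'std': pstdev(fold_values)}\n",
    "\n",
    "\n",
    "# Made-up per-fold RMSE values, not taken from the results above.\n",
    "demo_fold_rmse = [20.0, 16.0, 30.0]\n",
    "demo_summary = summarize_folds(demo_fold_rmse)\n",
    "print(demo_summary)"
   ]
  },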
  {
   "cell_type": "markdown",
   "id": "957b9a8e-f9ed-4b56-9ec2-eede0d3b6d04",
   "metadata": {},
   "source": [
    "## 3. Manual Benchmark\n",
    "If you need more control over the individual Experiments, you can set them up manually:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "184d06bf-d479-4f7e-a476-236ab51ccf0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from neuralprophet.benchmark import SimpleExperiment, CrossValidationExperiment\n",
    "from neuralprophet.benchmark import ManualBenchmark, ManualCVBenchmark"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "954116e3-4078-41d3-b7cd-c8888b8c03bb",
   "metadata": {},
   "source": [
    "### 3.1 ManualBenchmark: Manual SimpleExperiment Benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ee7c717-dc96-460a-ad01-941385d038f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "air_passengers_df = pd.read_csv(data_location + 'air_passengers.csv')\n",
    "peyton_manning_df = pd.read_csv(data_location + 'wp_log_peyton_manning.csv')\n",
    "metrics = [\"MAE\", \"MSE\", \"RMSE\", \"MASE\", \"MSSE\", \"MAPE\", \"SMAPE\"]\n",
    "experiments = [\n",
    "    SimpleExperiment(\n",
    "        model_class=NeuralProphetModel,\n",
    "        params={\"seasonality_mode\": \"multiplicative\", \"learning_rate\": 0.1},\n",
    "        data=Dataset(df=air_passengers_df, name=\"air_passengers\", freq=\"MS\"),\n",
    "        metrics=metrics,\n",
    "        test_percentage=25,\n",
    "    ),\n",
    "    SimpleExperiment(\n",
    "        model_class=ProphetModel,\n",
    "        params={\"seasonality_mode\": \"multiplicative\", },\n",
    "        data=Dataset(df=air_passengers_df, name=\"air_passengers\", freq=\"MS\"),\n",
    "        metrics=metrics,\n",
    "        test_percentage=25,\n",
    "    ),\n",
    "    SimpleExperiment(\n",
    "        model_class=NeuralProphetModel,\n",
    "        params={\"learning_rate\": 0.1},\n",
    "        data=Dataset(df=peyton_manning_df, name=\"peyton_manning\", freq=\"D\"),\n",
    "        metrics=metrics,\n",
    "        test_percentage=15,\n",
    "    ),\n",
    "    SimpleExperiment(\n",
    "        model_class=ProphetModel,\n",
    "        params={},\n",
    "        data=Dataset(df=peyton_manning_df, name=\"peyton_manning\", freq=\"D\"),\n",
    "        metrics=metrics,\n",
    "        test_percentage=15,\n",
    "    ),\n",
    "]\n",
    "benchmark = ManualBenchmark(\n",
    "    experiments=experiments,\n",
    "    metrics=metrics,\n",
    ")\n",
    "results_train, results_test = benchmark.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1a8c8ef4-d26a-4aa9-bfcc-b8134eab4daf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data</th>\n",
       "      <th>model</th>\n",
       "      <th>params</th>\n",
       "      <th>MAE</th>\n",
       "      <th>MSE</th>\n",
       "      <th>RMSE</th>\n",
       "      <th>MASE</th>\n",
       "      <th>MSSE</th>\n",
       "      <th>MAPE</th>\n",
       "      <th>SMAPE</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>NeuralProphet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative', 'learni...</td>\n",
       "      <td>23.922920</td>\n",
       "      <td>795.467295</td>\n",
       "      <td>28.204030</td>\n",
       "      <td>0.550857</td>\n",
       "      <td>0.305727</td>\n",
       "      <td>5.749847</td>\n",
       "      <td>2.766594</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>Prophet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative'}</td>\n",
       "      <td>29.818648</td>\n",
       "      <td>1142.139138</td>\n",
       "      <td>33.795549</td>\n",
       "      <td>0.686614</td>\n",
       "      <td>0.438966</td>\n",
       "      <td>7.471930</td>\n",
       "      <td>3.558548</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>peyton_manning</td>\n",
       "      <td>NeuralProphet</td>\n",
       "      <td>{'learning_rate': 0.1}</td>\n",
       "      <td>0.544561</td>\n",
       "      <td>0.409436</td>\n",
       "      <td>0.639872</td>\n",
       "      <td>1.542943</td>\n",
       "      <td>1.537802</td>\n",
       "      <td>7.006563</td>\n",
       "      <td>3.374985</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>peyton_manning</td>\n",
       "      <td>Prophet</td>\n",
       "      <td>{}</td>\n",
       "      <td>0.594528</td>\n",
       "      <td>0.463643</td>\n",
       "      <td>0.680913</td>\n",
       "      <td>1.684520</td>\n",
       "      <td>1.741397</td>\n",
       "      <td>7.673804</td>\n",
       "      <td>3.682554</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             data          model  \\\n",
       "0  air_passengers  NeuralProphet   \n",
       "1  air_passengers        Prophet   \n",
       "2  peyton_manning  NeuralProphet   \n",
       "3  peyton_manning        Prophet   \n",
       "\n",
       "                                              params        MAE          MSE  \\\n",
       "0  {'seasonality_mode': 'multiplicative', 'learni...  23.922920   795.467295   \n",
       "1             {'seasonality_mode': 'multiplicative'}  29.818648  1142.139138   \n",
       "2                             {'learning_rate': 0.1}   0.544561     0.409436   \n",
       "3                                                 {}   0.594528     0.463643   \n",
       "\n",
       "        RMSE      MASE      MSSE      MAPE     SMAPE  \n",
       "0  28.204030  0.550857  0.305727  5.749847  2.766594  \n",
       "1  33.795549  0.686614  0.438966  7.471930  3.558548  \n",
       "2   0.639872  1.542943  1.537802  7.006563  3.374985  \n",
       "3   0.680913  1.684520  1.741397  7.673804  3.682554  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_test"
   ]
  },
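  {
   "cell_type": "markdown",
   "id": "f5c3e8a1-2b74-4d9f-9e36-8a0b7d1c4f01",
   "metadata": {},
   "source": [
    "The manual benchmark above additionally reports the percentage errors MAPE and SMAPE. The sketch below uses one common definition of MAPE and a SMAPE variant that divides by `|y| + |y_hat|`. Both are assumptions for illustration: SMAPE in particular has several competing definitions, and the framework's choice is not shown here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5c3e8a1-2b74-4d9f-9e36-8a0b7d1c4f02",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mape(actual, predicted):\n",
    "    # Mean absolute percentage error; undefined when an actual value is zero.\n",
    "    return 100 * sum(abs(a - p) / abs(a) for a, p in zip(actual, predicted)) / len(actual)\n",
    "\n",
    "\n",
    "def smape(actual, predicted):\n",
    "    # Assumed variant: divide by |actual| + |predicted| (no extra factor of 2).\n",
    "    return 100 * sum(\n",
    "        abs(a - p) / (abs(a) + abs(p)) for a, p in zip(actual, predicted)\n",
    "    ) / len(actual)\n",
    "\n",
    "\n",
    "# Made-up numbers, purely for illustration.\n",
    "pct_actual = [100.0, 200.0]\n",
    "pct_predicted = [110.0, 180.0]\n",
    "print(mape(pct_actual, pct_predicted), smape(pct_actual, pct_predicted))"
   ]
  },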
  {
   "cell_type": "markdown",
   "id": "ee0214e5-ee8a-452d-baa3-c390027f9cf8",
   "metadata": {},
   "source": [
    "### 3.2 ManualCVBenchmark: Manual CrossValidationExperiment Benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b73463e4-75f0-4e81-a138-f834407f93b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "air_passengers_df = pd.read_csv(data_location + 'air_passengers.csv')\n",
    "experiments = [\n",
    "    CrossValidationExperiment(\n",
    "        model_class=NeuralProphetModel,\n",
    "        params={\"seasonality_mode\": \"multiplicative\", \"learning_rate\": 0.1},\n",
    "        data=Dataset(df=air_passengers_df, name=\"air_passengers\", freq=\"MS\"),\n",
    "        metrics=metrics,\n",
    "        test_percentage=10,\n",
    "        num_folds=3,\n",
    "        fold_overlap_pct=0,\n",
    "    ),\n",
    "    CrossValidationExperiment(\n",
    "        model_class=ProphetModel,\n",
    "        params={\"seasonality_mode\": \"multiplicative\", },\n",
    "        data=Dataset(df=air_passengers_df, name=\"air_passengers\", freq=\"MS\"),\n",
    "        metrics=metrics,\n",
    "        test_percentage=10,\n",
    "        num_folds=3,\n",
    "        fold_overlap_pct=0,\n",
    "    ),\n",
    "]\n",
    "benchmark_cv = ManualCVBenchmark(\n",
    "    experiments=experiments,\n",
    "    metrics=metrics,\n",
    ")\n",
    "results_summary, results_train, results_test = benchmark_cv.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "42dbdf41-4c39-4ed8-a42c-1ca52dba0938",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data</th>\n",
       "      <th>model</th>\n",
       "      <th>params</th>\n",
       "      <th>MAE</th>\n",
       "      <th>MSE</th>\n",
       "      <th>RMSE</th>\n",
       "      <th>MASE</th>\n",
       "      <th>MSSE</th>\n",
       "      <th>MAPE</th>\n",
       "      <th>SMAPE</th>\n",
       "      <th>MAE_std</th>\n",
       "      <th>MSE_std</th>\n",
       "      <th>RMSE_std</th>\n",
       "      <th>MASE_std</th>\n",
       "      <th>MSSE_std</th>\n",
       "      <th>MAPE_std</th>\n",
       "      <th>SMAPE_std</th>\n",
       "      <th>split</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>NeuralProphet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative', 'learni...</td>\n",
       "      <td>6.016460</td>\n",
       "      <td>61.312444</td>\n",
       "      <td>7.806123</td>\n",
       "      <td>0.282338</td>\n",
       "      <td>0.081969</td>\n",
       "      <td>3.053079</td>\n",
       "      <td>1.512873</td>\n",
       "      <td>0.573447</td>\n",
       "      <td>9.319967</td>\n",
       "      <td>0.613913</td>\n",
       "      <td>0.013311</td>\n",
       "      <td>0.009355</td>\n",
       "      <td>0.106123</td>\n",
       "      <td>0.053336</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>Prophet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative'}</td>\n",
       "      <td>6.655735</td>\n",
       "      <td>75.848123</td>\n",
       "      <td>8.616463</td>\n",
       "      <td>0.310869</td>\n",
       "      <td>0.098696</td>\n",
       "      <td>3.089578</td>\n",
       "      <td>1.553327</td>\n",
       "      <td>0.952939</td>\n",
       "      <td>20.968356</td>\n",
       "      <td>1.266764</td>\n",
       "      <td>0.021078</td>\n",
       "      <td>0.014689</td>\n",
       "      <td>0.261419</td>\n",
       "      <td>0.132790</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>NeuralProphet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative', 'learni...</td>\n",
       "      <td>19.514176</td>\n",
       "      <td>528.803367</td>\n",
       "      <td>22.342478</td>\n",
       "      <td>0.482325</td>\n",
       "      <td>0.233347</td>\n",
       "      <td>4.630512</td>\n",
       "      <td>2.340100</td>\n",
       "      <td>6.610696</td>\n",
       "      <td>252.724588</td>\n",
       "      <td>5.442155</td>\n",
       "      <td>0.135160</td>\n",
       "      <td>0.089620</td>\n",
       "      <td>1.386856</td>\n",
       "      <td>0.732939</td>\n",
       "      <td>test</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>air_passengers</td>\n",
       "      <td>Prophet</td>\n",
       "      <td>{'seasonality_mode': 'multiplicative'}</td>\n",
       "      <td>19.052098</td>\n",
       "      <td>536.695336</td>\n",
       "      <td>22.778341</td>\n",
       "      <td>0.479454</td>\n",
       "      <td>0.249282</td>\n",
       "      <td>4.604149</td>\n",
       "      <td>2.272174</td>\n",
       "      <td>3.876074</td>\n",
       "      <td>182.404522</td>\n",
       "      <td>4.224042</td>\n",
       "      <td>0.109195</td>\n",
       "      <td>0.103208</td>\n",
       "      <td>0.710556</td>\n",
       "      <td>0.353903</td>\n",
       "      <td>test</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             data          model  \\\n",
       "0  air_passengers  NeuralProphet   \n",
       "1  air_passengers        Prophet   \n",
       "0  air_passengers  NeuralProphet   \n",
       "1  air_passengers        Prophet   \n",
       "\n",
       "                                              params        MAE         MSE  \\\n",
       "0  {'seasonality_mode': 'multiplicative', 'learni...   6.016460   61.312444   \n",
       "1             {'seasonality_mode': 'multiplicative'}   6.655735   75.848123   \n",
       "0  {'seasonality_mode': 'multiplicative', 'learni...  19.514176  528.803367   \n",
       "1             {'seasonality_mode': 'multiplicative'}  19.052098  536.695336   \n",
       "\n",
       "        RMSE      MASE      MSSE      MAPE     SMAPE   MAE_std     MSE_std  \\\n",
       "0   7.806123  0.282338  0.081969  3.053079  1.512873  0.573447    9.319967   \n",
       "1   8.616463  0.310869  0.098696  3.089578  1.553327  0.952939   20.968356   \n",
       "0  22.342478  0.482325  0.233347  4.630512  2.340100  6.610696  252.724588   \n",
       "1  22.778341  0.479454  0.249282  4.604149  2.272174  3.876074  182.404522   \n",
       "\n",
       "   RMSE_std  MASE_std  MSSE_std  MAPE_std  SMAPE_std  split  \n",
       "0  0.613913  0.013311  0.009355  0.106123   0.053336  train  \n",
       "1  1.266764  0.021078  0.014689  0.261419   0.132790  train  \n",
       "0  5.442155  0.135160  0.089620  1.386856   0.732939   test  \n",
       "1  4.224042  0.109195  0.103208  0.710556   0.353903   test  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_summary"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "np-dev",
   "language": "python",
   "name": "np-dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
